#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path
import time

import numpy as np
import tensorflow as tf

from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import cv2

from detection import vehicle_detection
from detection import load_image_into_numpy_array
from classification import vehicle_classification

from flask import Flask, request

import uuid

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('class_model_path',
                           './freezed/classify/freezed.pb', """model for classification""")
tf.app.flags.DEFINE_string('detection_model_path',
                           './freezed/detection/frozen_inference_graph.pb', """model for detection""")
tf.app.flags.DEFINE_string('label_file', './labels.txt', """classification labels""")
tf.app.flags.DEFINE_string('upload_folder', './uploads', """folder for saving uploaded images""")
tf.app.flags.DEFINE_string('output_folder', './outputs',
                           """folder for saving annotated result images (served under /static)""")
# Integer flag, so the default is an int rather than the string '5001'.
tf.app.flags.DEFINE_integer('port', 5001, """port the server listens on""")
tf.app.flags.DEFINE_boolean('debug', False, """if debug""")
# Label map passed to vehicle_detection (distinct from FLAGS.label_file,
# which holds the classification labels).
PATH_TO_LABELS = './labels_items.txt'
UPLOAD_FOLDER = FLAGS.upload_folder

ALLOWED_EXTENSIONS = {'jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG'}

app = Flask(__name__)
# Annotated result images are written to output_folder and linked as
# /static/<name>, so the static folder is pointed there (public property
# instead of the private _static_folder attribute).
app.static_folder = FLAGS.output_folder

def allowed_files(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS.

    The comparison is case-insensitive, so ``photo.Jpg`` is accepted just like
    ``photo.jpg`` and ``photo.JPG``. Empty or missing names are rejected.

    Args:
        filename: client-supplied file name (may be empty or None).

    Returns:
        bool: whether the name ends in an allowed extension.
    """
    if not filename or '.' not in filename:
        return False
    ext = filename.rsplit('.', 1)[-1].lower()
    return ext in {allowed.lower() for allowed in ALLOWED_EXTENSIONS}

def rename_filename(old_filename):
    """Return a unique replacement name that keeps *old_filename*'s extension.

    Uploads are stored, and later linked under /static, by this name, so the
    user-supplied stem is discarded. A random UUID4 is used rather than UUID1,
    because UUID1 embeds the host's MAC address and a timestamp in the
    publicly visible name.

    Args:
        old_filename: client-supplied file name, possibly with a directory part.

    Returns:
        str: ``'<uuid4>' + ext``, where ``ext`` is the original extension
        including its dot (empty if there is none).
    """
    ext = os.path.splitext(os.path.basename(old_filename))[1]
    return str(uuid.uuid4()) + ext

def create_category_index(label_file, encoding=None):
    """Parse a classification label file into a category-index dict.

    Each line holds one or more comma-separated ``<id>:<name>`` entries, for
    example ``0:bus, 1:car``. Surrounding whitespace of each entry is stripped.
    Blank lines and empty entries (such as those left by a trailing comma)
    are skipped. Only the first ``:`` separates the id from the name, so a
    name may itself contain ``:``.

    Args:
        label_file: path to the label file.
        encoding: text encoding passed to ``open``. ``None`` (the default)
            keeps the platform's locale encoding, as before.

    Returns:
        dict: ``{id: {'id': id, 'name': name}}`` with integer ids. A repeated
        id overwrites the earlier entry.

    Raises:
        ValueError: if an entry's id part is not an integer.
        IndexError: if a non-empty entry has no ``:``.
    """
    category_index = {}
    with open(label_file, encoding=encoding) as f:
        for line in f:
            for entry in line.split(','):
                entry = entry.strip()
                if not entry:
                    continue
                parts = entry.split(':', 1)
                index = int(parts[0])
                category_index[index] = {'id': index, 'name': parts[1]}
    return category_index

def _draw_caption(draw, x, y, text, font):
    """Draw *text* at (x, y) over a light-grey strip 16 px tall."""
    # The strip width approximates the text width as 8 px per GBK byte (the
    # original heuristic). 'replace' keeps characters outside GBK from
    # raising UnicodeEncodeError instead of aborting the request.
    text_len = len(text.encode('gbk', 'replace'))
    draw.rectangle((x, y, x + text_len * 8, y + 16),
                   fill=(211, 211, 211), outline=None)
    draw.text((x, y), text, (255, 0, 255), font=font)

def _draw_box(draw, x1, y1, x2, y2):
    """Outline the rectangle (x1, y1)-(x2, y2) with 3 px green lines."""
    for segment in ((x1, y1, x2, y1), (x2, y1, x2, y2),
                    (x2, y2, x1, y2), (x1, y2, x1, y1)):
        draw.line(segment, (0, 255, 0), width=3)

def _classify_crop(crop, cls_model_path):
    """Classify a cropped PIL image; return (class_index, probability_percent)."""
    # vehicle_classification reads from a path, so the crop is written to a
    # per-call unique file. A shared fixed name would be overwritten by
    # concurrent requests, since the server runs with threaded=True.
    crop_path = os.path.join(UPLOAD_FOLDER, 'crop_%s.jpg' % uuid.uuid4().hex)
    crop.save(crop_path)
    try:
        result = vehicle_classification(crop_path, cls_model_path, FLAGS.label_file, 1)
    finally:
        try:
            os.remove(crop_path)
        except OSError:
            pass  # best-effort cleanup of the temporary crop
    # Flatten so a (1, C) prediction row is handled as well as a flat (C,) one.
    predictions = np.ravel(result[0])
    class_index = int(np.argmax(predictions))  # first maximum, as before
    prob = float(predictions[class_index]) * 100
    return class_index, prob

def _result_html(filename, message):
    """Build the HTML fragment linking the annotated image, followed by *message*."""
    new_url = '/static/%s' % os.path.basename(filename)
    return ('<img src="%s"></img><p>' % new_url) + message + '<BR>'

def inference(filename, od_model_path, cls_model_path,
              score_threshold=0.2, min_display_prob=15):
    """Detect vehicles in *filename*, classify each one, and render the result.

    An annotated copy of the image is written to FLAGS.output_folder under the
    same basename. Each kept detection gets a green box and a caption with its
    class name and probability, or '未识别该车型' ("model not recognized")
    when the probability is too low. If nothing is detected, the caption
    '未检测到车辆' ("no vehicle detected") is drawn at the top-left.

    Args:
        filename: path of the uploaded image.
        od_model_path: frozen detection graph passed to vehicle_detection.
        cls_model_path: frozen classification graph passed to
            vehicle_classification.
        score_threshold: detections whose score is not above this value are
            ignored (default 0.2, the previously hard-coded value).
        min_display_prob: percentage at or below which the class name is
            replaced by the "not recognized" caption (default 15).

    Returns:
        str: HTML fragment with an <img> pointing at /static/<basename> and a
        short text summary.

    Raises:
        IOError: if the annotated image cannot be written (for example, when
            the output folder does not exist).
    """
    result_detection = vehicle_detection(filename, 1, od_model_path, PATH_TO_LABELS)
    # Flatten to (N,) scores and (N, 4) boxes. Unlike np.squeeze, this keeps
    # the right rank even when only a single detection is returned.
    scores = np.ravel(result_detection[0])
    boxes = np.reshape(result_detection[2], (-1, 4))
    print(boxes.size)

    image = Image.open(filename)
    (image_np, im_height, im_width) = load_image_into_numpy_array(image)
    source = Image.fromarray(image_np, 'RGB')
    # Boxes are drawn on the same pixels they are cropped from. The canvas
    # stays in memory and is encoded once, instead of being re-read and
    # re-encoded from disk for every box.
    canvas = source.copy()
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.truetype('./simhei.ttf', 16, encoding='utf-8')
    category_index = create_category_index(label_file=FLAGS.label_file)
    output_path = os.path.join(FLAGS.output_folder, os.path.basename(filename))

    detections = [(bbox, score) for bbox, score in zip(boxes, scores)
                  if score > score_threshold]
    for bbox, _score in detections:
        # bbox is (ymin, xmin, ymax, xmax) as fractions of the image size,
        # which is what the scaling below assumes.
        y1 = bbox[0] * im_height
        x1 = bbox[1] * im_width
        y2 = bbox[2] * im_height
        x2 = bbox[3] * im_width
        # NOTE(review): the 0.9 / 1.1 factors scale absolute coordinates, so
        # the crop margin grows with distance from the origin. This is kept
        # unchanged because the classifier input depends on it.
        crop = source.crop([source.size[0] * bbox[1] * 0.9, source.size[1] * bbox[0] * 0.9,
                            source.size[0] * bbox[3] * 1.1, source.size[1] * bbox[2] * 1.1])
        class_index, prob = _classify_crop(crop, cls_model_path)
        class_name = category_index[class_index]['name']

        if prob > min_display_prob:
            disp_str = '%s: %d%%' % (class_name, int(prob))
        else:
            disp_str = '未识别该车型'
        _draw_caption(draw, x1, y1, disp_str, font)
        _draw_box(draw, x1, y1, x2, y2)

    if not detections:
        _draw_caption(draw, 0, 0, '未检测到车辆', font)

    # The canvas is RGB; OpenCV writes BGR.
    if not cv2.imwrite(output_path, cv2.cvtColor(np.asarray(canvas), cv2.COLOR_RGB2BGR)):
        raise IOError('could not write annotated image to %s' % output_path)

    if not detections:
        return _result_html(filename, "There's no car! (score: 0.0%)")
    return _result_html(filename, 'there are/is %d car(s)' % len(detections))

@app.route("/", methods=['GET', 'POST'])
def root():
    """Serve the upload form; on POST, run inference on the uploaded image.

    The uploaded file is saved to UPLOAD_FOLDER under a fresh UUID name. The
    HTML fragment produced by ``inference`` is then appended to the form.
    A POST without an acceptable file (missing field, empty name, or an
    extension rejected by ``allowed_files``) simply re-renders the form.
    """
    result = """
                <!doctype html>
                <title>车辆检测及型号识别</title>
                <h1>上传待检测图片</h1>
                <form action="" method=post enctype=multipart/form-data>
                <p><input type=file name=file value='选择图片'>
                    <input type=submit value='上传'>
                </form>
                <p>%s</p>
                """ % "<br>"
    if request.method != 'POST':
        return result

    # .get() avoids a 400 BadRequestKeyError when the form field is absent.
    upload = request.files.get('file')
    if not upload or not allowed_files(upload.filename):
        return result

    file_path = os.path.join(UPLOAD_FOLDER, rename_filename(upload.filename))
    upload.save(file_path)
    print('file saved to %s' % file_path)

    start_time = time.time()
    out_html = inference(file_path,
                         od_model_path=FLAGS.detection_model_path,
                         cls_model_path=FLAGS.class_model_path)
    duration = time.time() - start_time
    print('duration:[%.0fms]' % (duration * 1000))
    return result + out_html

if __name__ == "__main__":
    # Create the working folders up front: saving an upload raises, and
    # cv2.imwrite fails, when the target directory does not exist.
    for folder in (FLAGS.upload_folder, FLAGS.output_folder):
        if not os.path.isdir(folder):
            os.makedirs(folder)
    print('listening on port %d' % FLAGS.port)
    # NOTE(review): nothing visible in this file reads app.sess. A TF 1.x
    # Session created with the default config may reserve GPU memory.
    # Confirm no other module uses it before removing.
    sess = tf.Session()
    app.sess = sess
    app.run(host='127.0.0.1', port=FLAGS.port, debug=FLAGS.debug, threaded=True)